""" Import libraries """
import argparse
import os

import copy
import sys
import time

import numpy as np
from tensorforce.agents import Agent
from tensorforce.environments import Environment
import matplotlib.pyplot as plt
import tensorflow as tf

from model.env_rl_helper import RLModelHelper
from utils.utils_helper import UtilsHelper

""" Parser related """
# Command-line interface. Both options are integers; --respath is restricted to
# the two rounds the script knows how to handle so that a typo is rejected by
# argparse before any model is loaded.
parser = argparse.ArgumentParser(description="run_file")
parser.add_argument(
    '--seed', default=0, type=int,
    help='random seed for NumPy and TensorFlow; also selects the '
         'seed_<SEED>_train agent/result sub-directory (default: 0)',
)
parser.add_argument(
    '--respath', default=0, type=int, choices=(0, 1),
    help='0: first round of online implementation (agent loaded from '
         '"results"); 1: second round (agent loaded from '
         '"results_online_implementation") (default: 0)',
)
args = parser.parse_args()
# Seed both random number generators so that a run is reproducible for a given --seed.
np.random.seed(args.seed)
tf.random.set_seed(args.seed)

"""
The experiment
"""
# Threshold on the per-step update counter (num_updates) inside the retraining loop:
# below it, retraining uses agent.act()/agent.observe(); from it onwards, the loop
# feeds experiences returned by env.backup_findSafeU() to agent.experience()/agent.update().
online_training_act_and_observe = 50
# Upper bound on num_updates for one unsafe step: once num_updates reaches this value
# the retraining loop stops updating and takes the action from env.backup_findClosestU().
max_online_updates = 100  # In the paper the condition is "update<=maxItr". Here we use "update<maxItr" because Python counts from 0

""" Load the environment and the RL agent """
# Select where the trained agent is read from (agent_path) and where this run
# writes its results (result_path). The second round continues from the output
# of the first round.
if args.respath == 0:  # first round of online implementation
    agent_path = 'results'
    result_path = 'results_online_implementation'
elif args.respath == 1:  # second round of online implementation
    agent_path = 'results_online_implementation'
    result_path = 'results_online_implementation_r2'
else:
    # sys.exit(str) writes the message to stderr and exits with status 1, so a
    # calling shell script can detect the failure. (quit() is an interactive
    # helper injected by the site module and exits with status 0.)
    sys.exit("The argument for respath is invalid, please define it carefully")

# Directory of the saved agent for this seed.
agent_directory = agent_path+'/seed_{}_train'.format(args.seed)
# agent_directory = 'results/seed_{}_agent_retrained'.format(args.seed)
# initial_states = np.load('data/initial_states_10k_episodes_smallCIS_seed20.npy')

rl_helper = RLModelHelper()  # For plant simulator
# The Tensorforce environment is created from the RLModelHelper class itself.
# NOTE(review): max_episode_timesteps=200 matches num_sim used by the simulation loop — keep them in sync.
env = Environment.create(environment=RLModelHelper, max_episode_timesteps=200)  #todo: one important change
agent = Agent.load(directory=agent_directory, filename='ppo_cstr_0p0001lr_10000episodes')  # Load trained agent

# Pre-sampled initial states; the simulation loop reads entry [episode, 0, :] for each episode.
initial_states_presampled = np.load('initial_states_10kepisodes.npy')

""" Initialize simulation """
# Logs that receive exactly one entry per episode.
score_history, rewards_history, actions_history = [], [], []
states_history, fail_history, time_history = [], [], []
bad_state_transitions_history = []
# Log that receives one [episode, step, num_updates, seconds] entry for every
# step that needed at least one online update.
time_failed_step_history = []

# Size of the experiment: number of episodes and steps per episode.
n_episodes, num_sim = 10000, 200

# Running counters over the whole experiment.
bad_epi = retrain_accumulate = 0
# Early-exit flag checked at the end of every step; it starts cleared.
giveup = False

""" Test the RL agent """
# Evaluate the loaded agent episode by episode. Whenever a step returns a negative
# reward, the inner while-loop retrains the agent online (or, after too many updates,
# falls back to the backup table) until a replacement action for that step is found.
print('Start simulation')
# NOTE: `_` is the episode index; it is used for indexing and logging throughout the body.
for _ in range(n_episodes):
    # Timer idiom: store -t_start now and add t_end later to obtain the elapsed seconds.
    time_one_episode = -time.time()
    # if _ == 105 or _ == 332 or _ == 973 or _ == 1909:     # These are failed episodes for seed 6
    #     print('breakpoint')
    # NOTE(review): actual_reset presumably selects a full reset inside the environment — confirm in RLModelHelper.
    env.actual_reset = True
    states = env.reset()
    # Replace the state returned by reset() with the pre-sampled initial state of this episode.
    states, env.current_state = initial_states_presampled[_,0,:], initial_states_presampled[_,0,:]
    # states = initial_states[_,:]
    # env.current_state = initial_states[_,:]
    # Per-episode trajectories; states_episode starts with the initial state.
    rewards_episode = []
    actions_episode = []
    states_episode = [states]
    bad_state_transitions_episode = []
    score = 0
    reset_x = 0  # How many times the agent reset the state.
    optimization_wrong = 0  # How many times the final CIS check below failed in this episode.
    internals = agent.initial_internals()
    independent = True  # Initialize the agent as an offline one
    for t in range(num_sim):
        time_one_step = -time.time()
        num_updates = 0
        # Keep a copy of the state before acting so the retraining loop can restart from it.
        previous_state = copy.deepcopy(states)
        # independent is always True here (it is restored before the while-loop exits),
        # so act() is called with internals and returns an (actions, internals) pair.
        actions, internals = agent.act(states=states, internals=internals, independent=independent)
        states, terminal, reward = env.execute(actions=actions)     # In diagram, it is "Model" in "Safety Supervisor"
        """ If the state is outside of CIS """
        # TODO: this if condition is simply as the same as agent.observe. This is only 1 step forward
        if reward < 0:  # TODO: may use a better condition since we may not use 10000 in the reward design, or always have <0 or >0 conditions for all reward functions
            notInCIS = True  # Set the state is not in CIS
            retrain = 0  # Number of while-loop iterations for this step.
            num_updates = 0  # Number of agent updates performed for this step.
            whileloop = 0  # Steps taken since the last environment reset inside the retraining loop.
            reset_env_online = 1    # every "reset_env_online" step, we reset the environment
            independent = False  # Switch to training mode: act() must now be paired with observe().
            # Everything tried during retraining for this step (one inner list per attempt).
            tried_actions_rl = []
            tried_states_rl = []
            tried_rewards_rl = []
            print('---------------------------------------------------------------------------')
            print(f'At episode {_}, step {t}')
            # Loop until an action with non-negative reward is found from previous_state,
            # or the update budget is exhausted and the backup table is used.
            while notInCIS == True:
                retrain_accumulate += 1
                retrain += 1
                # Defaults to 1; overwritten by the return value of agent.observe() in the act/observe branch.
                update_index = 1
                # print('Retrain', retrain, 'times.', 'In total retrain', retrain_accumulate, 'times')

                if num_updates < max_online_updates:
                    # The following if-else corresponds to "Retrain with (xk,uk,rk+1,xk+1)"
                    if num_updates < online_training_act_and_observe and notInCIS == True:
                        # Retrain
                        tried_actions_episode = []
                        tried_states_episode = [previous_state]
                        tried_rewards_episode = []
                        # With reset_env_online = 1 this condition holds on every iteration.
                        if whileloop % reset_env_online == 0:   # if reset_env_online=200, then we use the same approach as offline training
                            env.actual_reset = False
                            states_redundant = env.reset()  #todo: one important change. make "_timestep" be 0.
                            whileloop = 0
                        # Reset state back to previous unsafe state.
                        # If we put this line out of if loop, then every step we start with unsafe state.
                        # If we put it under the if loop, then only the initial state is the unsafe state.
                        # The second approach is not doable, because when the initial state is unsafe,
                        # the following states may violate state boundaries given in RL environment.)
                        env.current_state = copy.deepcopy(previous_state)
                        # independent is False here, so only the actions are returned.
                        actions = agent.act(states=previous_state, independent=independent)
                        states, terminal, reward = env.execute(actions=actions)     # Here, the action could be safe or unsafe
                        whileloop += 1
                        if whileloop == reset_env_online:   # When reset_env_online=1, it means each episode only has 1 step. So we terminate the episode after 1 step.
                            terminal = True
                        # NOTE(review): the return value is presumably the number of updates performed
                        # (0/1 or False/True) — confirm against the Tensorforce version in use.
                        update_index = agent.observe(terminal=terminal, reward=reward)  # Observes reward and whether a terminal state is reached, needs to be preceded by act()
                        num_updates += update_index
                        # print('Number of updates is', num_updates)
                        tried_actions_episode.append(actions['Tc'])
                        tried_states_episode.append(states)
                        tried_rewards_episode.append(reward)
                        tried_actions_rl.append(tried_actions_episode)
                        tried_states_rl.append(tried_states_episode)
                        tried_rewards_rl.append(tried_rewards_episode)

                    else:
                        # The backup experiences are queried only once, when num_updates equals the
                        # threshold; later iterations reuse the same episode_* values.
                        # NOTE(review): if num_updates ever jumped past the threshold without equalling it,
                        # episode_states would be undefined below — verify observe() never returns more than 1.
                        if num_updates == online_training_act_and_observe:
                            episode_states, episode_actions, episode_terminal, episode_reward = env.backup_findSafeU(previous_state)    # todo: no next state?
                        if len(episode_states) == 0:
                            # No experience returned for this state: report it and stop the whole script.
                            print('This state does not belong to CIS because there is no safe action. CIS is wrong')
                            print(f'At episode {_}, step {t}. This state is {previous_state}')
                            env.examine_state(previous_state)
                            # env.plot_state(previous_state)
                            quit()
                        else:
                            # Feed recorded experience to agent
                            agent.experience(
                                states=episode_states, actions=episode_actions,
                                terminal=episode_terminal, reward=episode_reward
                            )

                            # Perform update
                            agent.update()
                            update_index = 1
                            num_updates += update_index
                            # print('Number of updates is', num_updates)

                    # Examine retrained RL. In diagram, it's corresponding to "If xk+1 in CIS"
                    # Only tested when this iteration actually counted an update.
                    if update_index == 1:
                        print('Number of updates is', num_updates)
                        # print('Testing updated RL')
                        env.actual_reset = False
                        states_redundant = env.reset()  #todo:one important change
                        env.current_state = copy.deepcopy(previous_state)   # reset state back to previous unsafe state
                        actions = agent.act(states=previous_state, independent=True)  # In this examine, we only run act in independent mode for 1 step. Will this affect the memory? No, it does not
                        states, terminal, reward = env.execute(actions=actions)

                        if reward >= 0:  # TODO: may use a better condition. This one is the opposite of the previous one
                            # print('Testing succeed')
                            env.actual_reset = False
                            states_redundant = env.reset()  # todo: one important change. After retrain, reset to reset retrain steps
                            env.current_state = copy.deepcopy(states)   # reset state to safe state
                            notInCIS = False  # Exit the while loop
                            independent = True  # Make an offline agent again
                        #     print('^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^')
                        # else:
                        #     print('Testing failed')
                else:   # Backup table
                    print('Cannot find the safe action, using the backup table.')
                    # backup table
                    env.actual_reset = False
                    states_redundant = env.reset()  # todo:one important change
                    env.current_state = copy.deepcopy(previous_state)   # reset state back to previous unsafe state
                    # NOTE(review): relies on episode_actions having been set by the backup_findSafeU branch above,
                    # which requires max_online_updates > online_training_act_and_observe.
                    # `actions` passed as `unsafe` is the last action tried by the agent.
                    actions = env.backup_findClosestU(previous_state, episode_actions, unsafe=actions)  # Action taken from the backup table instead of the agent
                    states, terminal, reward = env.execute(actions=actions)
                    env.current_state = copy.deepcopy(states)   # todo: maybe redundant
                    notInCIS = False  # Exit the while loop
                    independent = True  # Make an offline agent again
                    reset_x += 1    # We do not giveup anymore.
                    # print('^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^')

        # Elapsed wall-clock time of this step, including any online retraining.
        time_one_step += time.time()
        if num_updates > 0:
            print('Episode {}, step {}. Used {:.2f} seconds'.format(_,t,time_one_step))
            time_one_failed_step = [_, t, num_updates, time_one_step]
            time_failed_step_history.append(time_one_failed_step)
        """ Finish one step """
        # Record the transition that was finally applied (after retraining or backup, if any).
        score += reward
        rewards_episode.append(reward)
        actions_episode.append(actions['Tc'])
        states_episode.append(states)
        # Every step is recorded here; the list is only kept at episode end if the episode was bad.
        bad_state_transitions_episode.append([_, t, previous_state, actions['Tc'], reward, states])
        # NOTE(review): giveup is never set to True in the visible code, so this branch is currently inactive.
        if giveup == True:
            print('Leaving the episode', _)
            giveup = False
            reset_x += 1
            break

        # Row-wise check of hrepABig @ states <= hrepBBig + tol; any violated row ends the episode.
        if (np.matmul(env.hrepABig, states) > env.hrepBBig + env.tol).any():  # This also include the case that states is violating the physical constraint
            # score -= reward
            # score += -1000
            # states = copy.deepcopy(previous_state)
            # env.current_state = copy.deepcopy(previous_state)
            # reset_x += 1
            # sys.exit('The state should not be outside of CIS. Check code.')
            print('This final state is outside of CIS, yet the optimization said the action was safe')
            print(f'At episode {_}, step {t}. This previous state is {previous_state}, the action is {actions["Tc"]} and the final state is {states}')
            env.examine_state(states)
            # env.plot_state(states)
            optimization_wrong += 1
            break

    """ Finish one episode """
    # An episode is "bad" when the backup table was used or the final CIS check failed.
    if reset_x != 0 or optimization_wrong != 0:  # Resetting x happened
        bad_epi += 1
        bad_state_transitions_history.append(bad_state_transitions_episode)
    else:
        # Keep one (empty) entry per episode so indices line up with the other histories.
        bad_state_transitions_history.append([])
    time_one_episode += time.time()
    print('{} bad episodes out of {} episodes. Resetted {} times. Optimization was wrong {} times. This episode taks {:.2f} secs'.format(bad_epi, _+1, reset_x, optimization_wrong, time_one_episode))
    score_history.append(score)
    fail_history.append(reset_x)
    # if _%100 == 0:
    rewards_history.append(rewards_episode)
    actions_history.append(actions_episode)
    states_history.append(states_episode)
    time_history.append(time_one_episode)

# `_` still holds the index of the last episode here, so _+1 is the number of episodes run.
print('Bad episode results (online implementation):', bad_epi, 'out of', _+1)
# agent.close()  # The line will delete the model inside the agent, hence we could not save it anymore
# env.close()

""" Save the agent and the reward trajectory """
# Base name shared by every saved artefact of this run.
fname = agent.spec['agent'] + '_cstr_' + '0p0001' + 'lr_' + str(n_episodes) + 'episodes'
# Per-seed output directory. The trailing slash is kept so that the file
# prefixes below can simply be concatenated to it.
save_dir = result_path + '/seed_{}_train/'.format(args.seed)
# Create the directory together with any missing parent; an already existing
# directory is fine. Any other failure (e.g. permissions) is raised here rather
# than surfacing later as a confusing np.save error after hours of simulation.
os.makedirs(save_dir, exist_ok=True)

scores_file = save_dir + 'scores_' + fname
rewards_file = save_dir + 'rewards_' + fname
actions_file = save_dir + 'actions_' + fname
states_file = save_dir + 'states_' + fname
fail_file = save_dir + 'fail_' + fname
bad_states_file = save_dir + 'bad_states_' + fname
time_file = save_dir + 'time_' + fname
# No underscore before fname here: kept as-is so existing result loaders still find the file.
time_failed_step_file = save_dir + 'time_failed_step' + fname

# NOTE(review): the per-episode histories can be ragged when an episode ends early;
# recent NumPy releases refuse to build ragged arrays implicitly — confirm the
# NumPy version used with this script.
np.save(scores_file, score_history)
np.save(rewards_file, rewards_history)
np.save(actions_file, actions_history)
np.save(states_file, states_history)
# Per-episode count of backup-table resets; previously collected but never written.
np.save(fail_file, fail_history)
np.save(bad_states_file, bad_state_transitions_history)
np.save(time_file, time_history)
np.save(time_failed_step_file, time_failed_step_history)
# Persist the (possibly retrained) agent next to the trajectories.
agent.save(save_dir, filename=fname)
